{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "keras_bert_classification_tpu.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "doNFRjPqiBhM",
        "colab_type": "code",
        "outputId": "725008ed-61f8-433c-82ce-9293bc07629c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 289
        }
      },
      "source": [
        "# @title Preparation\n",
        "!pip install -q keras-bert keras-rectified-adam\n",
        "!wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n",
        "!unzip -o uncased_L-12_H-768_A-12.zip"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  Building wheel for keras-bert (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for keras-rectified-adam (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for keras-transformer (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for keras-pos-embd (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for keras-multi-head (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for keras-layer-normalization (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for keras-position-wise-feed-forward (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for keras-embed-sim (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for keras-self-attention (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Archive:  uncased_L-12_H-768_A-12.zip\n",
            "   creating: uncased_L-12_H-768_A-12/\n",
            "  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.meta  \n",
            "  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001  \n",
            "  inflating: uncased_L-12_H-768_A-12/vocab.txt  \n",
            "  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.index  \n",
            "  inflating: uncased_L-12_H-768_A-12/bert_config.json  \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bzoFRUGmh6a3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# @title Constants\n",
        "\n",
        "SEQ_LEN = 128\n",
        "BATCH_SIZE = 128\n",
        "EPOCHS = 5\n",
        "LR = 1e-4"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KUQ8UtquieFj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# @title Environment\n",
        "import os\n",
        "\n",
        "pretrained_path = 'uncased_L-12_H-768_A-12'\n",
        "config_path = os.path.join(pretrained_path, 'bert_config.json')\n",
        "checkpoint_path = os.path.join(pretrained_path, 'bert_model.ckpt')\n",
        "vocab_path = os.path.join(pretrained_path, 'vocab.txt')\n",
        "\n",
        "# TF_KERAS must be added to environment variables in order to use TPU\n",
        "os.environ['TF_KERAS'] = '1'"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gGzsxkLTpRrs",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 170
        },
        "outputId": "6bc483bd-fee3-4cb1-c317-3ff6983389f2"
      },
      "source": [
        "# @title Initialize TPU Strategy\n",
        "\n",
        "import tensorflow as tf\n",
        "from keras_bert import get_custom_objects\n",
        "\n",
        "TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']\n",
        "resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)\n",
        "tf.contrib.distribute.initialize_tpu_system(resolver)\n",
        "strategy = tf.contrib.distribute.TPUStrategy(resolver)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING: Logging before flag parsing goes to stderr.\n",
            "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
            "For more information, please see:\n",
            "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
            "  * https://github.com/tensorflow/addons\n",
            "  * https://github.com/tensorflow/io (for I/O related ops)\n",
            "If you depend on functionality not listed there, please file an issue.\n",
            "\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sVTPNxOyj4HJ",
        "colab_type": "code",
        "outputId": "a6b0bc85-93cb-4642-ea81-77d78af6a186",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 139
        }
      },
      "source": [
        "# @title Load Basic Model\n",
        "import codecs\n",
        "from keras_bert import load_trained_model_from_checkpoint\n",
        "\n",
        "token_dict = {}\n",
        "with codecs.open(vocab_path, 'r', 'utf8') as reader:\n",
        "    for line in reader:\n",
        "        token = line.strip()\n",
        "        token_dict[token] = len(token_dict)\n",
        "\n",
        "with strategy.scope():\n",
        "    model = load_trained_model_from_checkpoint(\n",
        "        config_path,\n",
        "        checkpoint_path,\n",
        "        training=True,\n",
        "        trainable=True,\n",
        "        seq_len=SEQ_LEN,\n",
        "    )"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xioN-O_vtztC",
        "colab_type": "code",
        "outputId": "05526ce5-d0fe-4ede-fe4a-199997244f27",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "# @title Download IMDB Data\n",
        "import tensorflow as tf\n",
        "\n",
        "dataset = tf.keras.utils.get_file(\n",
        "    fname=\"aclImdb.tar.gz\", \n",
        "    origin=\"http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\", \n",
        "    extract=True,\n",
        ")"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading data from http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
            "84131840/84125825 [==============================] - 3s 0us/step\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xfC3Nh8pnckd",
        "colab_type": "code",
        "outputId": "00f6e5d4-f6ca-4626-9b3b-9525efbf4596",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 85
        }
      },
      "source": [
        "# @title Convert Data to Array\n",
        "import os\n",
        "import numpy as np\n",
        "from tqdm import tqdm\n",
        "from keras_bert import Tokenizer\n",
        "\n",
        "tokenizer = Tokenizer(token_dict)\n",
        "\n",
        "\n",
        "def load_data(path):\n",
        "    global tokenizer\n",
        "    indices, sentiments = [], []\n",
        "    for folder, sentiment in (('neg', 0), ('pos', 1)):\n",
        "        folder = os.path.join(path, folder)\n",
        "        for name in tqdm(os.listdir(folder)):\n",
        "            with open(os.path.join(folder, name), 'r') as reader:\n",
        "                  text = reader.read()\n",
        "            ids, segments = tokenizer.encode(text, max_len=SEQ_LEN)\n",
        "            indices.append(ids)\n",
        "            sentiments.append(sentiment)\n",
        "    items = list(zip(indices, sentiments))\n",
        "    np.random.shuffle(items)\n",
        "    indices, sentiments = zip(*items)\n",
        "    indices = np.array(indices)\n",
        "    mod = indices.shape[0] % BATCH_SIZE\n",
        "    if mod > 0:\n",
        "        indices, sentiments = indices[:-mod], sentiments[:-mod]\n",
        "    return [indices, np.zeros_like(indices)], np.array(sentiments)\n",
        "  \n",
        "  \n",
        "train_path = os.path.join(os.path.dirname(dataset), 'aclImdb', 'train')\n",
        "test_path = os.path.join(os.path.dirname(dataset), 'aclImdb', 'test')\n",
        "\n",
        "\n",
        "train_x, train_y = load_data(train_path)\n",
        "test_x, test_y = load_data(test_path)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 12500/12500 [00:44<00:00, 277.81it/s]\n",
            "100%|██████████| 12500/12500 [00:46<00:00, 268.06it/s]\n",
            "100%|██████████| 12500/12500 [00:44<00:00, 282.12it/s]\n",
            "100%|██████████| 12500/12500 [00:45<00:00, 277.77it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OhMA1j7wnqSm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# @title Build Custom Model\n",
        "from tensorflow.python import keras\n",
        "from keras_radam import RAdam\n",
        "\n",
        "with strategy.scope():\n",
        "    inputs = model.inputs[:2]\n",
        "    dense = model.get_layer('NSP-Dense').output\n",
        "    outputs = keras.layers.Dense(units=2, activation='softmax')(dense)\n",
        "    \n",
        "    model = keras.models.Model(inputs, outputs)\n",
        "    model.compile(\n",
        "        RAdam(lr=LR),\n",
        "        loss='sparse_categorical_crossentropy',\n",
        "        metrics=['sparse_categorical_accuracy'],\n",
        "    )"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jmOLb7lWvDvl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# @title Initialize Variables\n",
        "import tensorflow as tf\n",
        "import tensorflow.keras.backend as K\n",
        "\n",
        "sess = K.get_session()\n",
        "uninitialized_variables = set([i.decode('ascii') for i in sess.run(tf.report_uninitialized_variables())])\n",
        "init_op = tf.variables_initializer(\n",
        "    [v for v in tf.global_variables() if v.name.split(':')[0] in uninitialized_variables]\n",
        ")\n",
        "sess.run(init_op)"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QgP7bCQxrZpQ",
        "colab_type": "code",
        "outputId": "d2167bd6-2a03-48bd-e6c6-d31c63c57e65",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 204
        }
      },
      "source": [
        "# @title Fit\n",
        "\n",
        "model.fit(\n",
        "    train_x,\n",
        "    train_y,\n",
        "    epochs=EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        ")"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 1/5\n",
            "195/195 [==============================] - 84s 432ms/step - loss: 0.7142 - sparse_categorical_accuracy: 0.5161\n",
            "Epoch 2/5\n",
            "195/195 [==============================] - 38s 196ms/step - loss: 0.6928 - sparse_categorical_accuracy: 0.5467\n",
            "Epoch 3/5\n",
            "195/195 [==============================] - 38s 195ms/step - loss: 0.6061 - sparse_categorical_accuracy: 0.6643\n",
            "Epoch 4/5\n",
            "195/195 [==============================] - 38s 195ms/step - loss: 0.4952 - sparse_categorical_accuracy: 0.7586\n",
            "Epoch 5/5\n",
            "195/195 [==============================] - 38s 195ms/step - loss: 0.4192 - sparse_categorical_accuracy: 0.8067\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<tensorflow.python.keras.callbacks.History at 0x7f65f69d5908>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZBSba3vprlRD",
        "colab_type": "code",
        "outputId": "b1afd7a1-dd12-44c1-cb6e-0631d5896311",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# @title Predict\n",
        "\n",
        "predicts = model.predict(test_x, verbose=True).argmax(axis=-1)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Wo1aps8prrCq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# @title Accuracy\n",
        "\n",
        "print(np.sum(test_y == predicts) / test_y.shape[0])"
      ],
      "execution_count": 13,
      "outputs": []
    }
  ]
}